import os
import cv2
import pdb
import copy
import random
import argparse
import numpy as np
from tqdm import tqdm
from glob import escape, glob
from imantics import Polygons, Mask
from matplotlib.path import Path
from operator import itemgetter
from shapely.geometry import Polygon
random.seed(123)  # fixed seed so the random.choice outline colours in draw_symbol/draw_hull are reproducible

# Colour palette indexed by cluster / role id: kmeans paints cluster i with PALLTE[i],
# seg2geo recovers cluster i by matching PALLTE[i], and the draw_* helpers use these
# triples as outline colours. Its 11 entries are the upper bound for the k-means K.
PALLTE = np.array([[255, 255, 255], [128, 0, 0], [0, 128, 0], [0, 0, 128], [128, 128, 0], \
         [128, 0, 128], [0, 128, 128], [128, 128, 128], [255, 0, 0], [0, 255, 0], [0, 0, 255]])


def kmeans(image, k):
    '''Cluster the pixels of an optical-flow field into ``k`` groups with k-means.

    Every pixel's channel vector is one k-means sample. Each pixel is then
    painted with the PALLTE colour of its cluster index, so seg2geo can later
    recover cluster ``i`` by matching ``PALLTE[i]``.

    Args:
        image: array of shape (H, W, C); the C values of each pixel form one sample.
        k: number of clusters; must not exceed ``len(PALLTE)``.

    Returns:
        (H, W, 3) array where each pixel holds ``PALLTE[cluster_index]``.

    Raises:
        ValueError: if ``k`` is larger than the number of palette colours.
    '''
    if k > len(PALLTE):
        raise ValueError("k={} exceeds the {} colours available in PALLTE".format(k, len(PALLTE)))

    pixel_values = np.float32(image.reshape((-1, image.shape[-1])))

    # Stop after 10 iterations or once accuracy epsilon=1 is reached; OpenCV runs
    # 10 attempts from random centres and keeps the most compact labelling.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1)
    _, labels, _ = cv2.kmeans(pixel_values, k, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)

    segmented_image = PALLTE[labels.flatten()]
    return segmented_image.reshape([image.shape[0], image.shape[1], 3])

def get_disconnected_region(mask):
    '''Split a binary mask into one mask per connected component.

    Args:
        mask: single-channel uint8 mask (non-zero = foreground), as produced by
            filter_mask inside seg2geo.

    Returns:
        list of uint8 masks, one per connected component in label order; each is
        255 on that component's pixels and 0 everywhere else.
    '''
    num_labels, labels = cv2.connectedComponents(mask)
    # Label 0 is the background, so the components are labels 1..num_labels-1.
    # Each mask is built from zeros so that the label values of the *other*
    # components (and labels > 255 wrapped by a uint8 cast) cannot leak into it.
    return [np.where(labels == idx, 255, 0).astype(np.uint8) for idx in range(1, num_labels)]

def filter_mask(mask, kernel_size=3, DO_MORPH=False, MORPH_ITER=0):
    '''Optionally clean a cluster mask with a morphological opening.

    Args:
        mask: binary mask derived from the clustered optical flow.
        kernel_size: side length of the square all-ones structuring element.
        DO_MORPH: when False the input mask is returned unchanged (same object).
        MORPH_ITER: iteration count passed to cv2.morphologyEx.

    Returns:
        The opened mask, or ``mask`` itself when DO_MORPH is False.
    '''
    if not DO_MORPH:
        return mask
    structuring_element = np.ones((kernel_size, kernel_size), np.uint8)
    return cv2.morphologyEx(mask, cv2.MORPH_OPEN, structuring_element, iterations=MORPH_ITER)

def get_hull_of(of, hull):
    '''Return the median optical flow of the pixels inside a polygon.

    Args:
        of: flow array of shape (H, W, >=2); channels 0 and 1 are sampled.
        hull: (N, 2) array of polygon vertices in (x, y) = (column, row) order,
            matching the pixel grid built below.

    Returns:
        [median of channel 0, median of channel 1] over the pixel coordinates
        that matplotlib's ``Path.contains_points`` reports as inside ``hull``;
        ``[0, 0]`` when no pixel is inside. This is a median, even though
        callers store the result under "mean_of".
    '''
    of_h, of_w = of.shape[:2]
    cols, rows = np.meshgrid(np.arange(of_w), np.arange(of_h))
    points = np.column_stack((cols.ravel(), rows.ravel()))  # (x, y) per pixel
    inside = Path(hull.tolist()).contains_points(points)
    if not inside.any():
        return [0, 0]
    flow_x = of[:, :, 0].ravel()
    flow_y = of[:, :, 1].ravel()
    return [np.median(flow_x[inside]), np.median(flow_y[inside])]


def seg2geo(image, of, param_dict, image_path, mask_dir=None):
    '''Convert a clustered optical-flow image into filtered convex hulls.

    Every cluster colour produced by ``kmeans`` becomes a binary mask. The mask
    is optionally cleaned with a morphological opening and then split into
    connected regions. Each region's convex hull is kept only if it:
      * has more than 3 vertices,
      * has IoU with the whole-frame rectangle below FILTER_IOU_THRES,
      * has area above FILTER_AREA_THRES,
      * has sum(|median flow|) above FILTER_OF_THRES.

    Args:
        image: (H, W, 3) clustered image from ``kmeans``; cluster ``i`` is
            painted with ``PALLTE[i]``.
        of: optical-flow array sampled by ``get_hull_of``.
        param_dict: configuration; reads K, MORPH_KS, DO_MORPH, MORPH_ITER,
            FILTER_IOU_THRES, FILTER_AREA_THRES, FILTER_OF_THRES and
            IMAGE_SHAPE (height, width, ...).
        image_path: path of the source flow file; its basename prefixes the keys.
        mask_dir: if not None, each kept region mask is written into this
            directory as ``<stem>_<cluster>_<region>_contour.png``.

    Returns:
        dict keyed by ``os.path.join(basename, "<cluster>_<region>")``. Each
        value is a dict with:
          * "hull": float16 (N, 2) vertices,
          * "center": vertex mean,
          * "mean_of": float16 median flow,
          * "next_obj": always None here.
    '''
    k = param_dict["K"]
    MORPH_KS = param_dict["MORPH_KS"]
    DO_MORPH = param_dict["DO_MORPH"]
    MORPH_ITER = param_dict["MORPH_ITER"]
    FILTER_IOU_THRES = param_dict["FILTER_IOU_THRES"]
    FILTER_AREA_THRES = param_dict["FILTER_AREA_THRES"]
    FILTER_OF_THRES = param_dict["FILTER_OF_THRES"]
    IMAGE_SHAPE = param_dict["IMAGE_SHAPE"]
    img_h, img_w = IMAGE_SHAPE[0], IMAGE_SHAPE[1]
    # Hull vertices come from OpenCV contours and are (x, y) = (column, row),
    # so the frame rectangle spans x in [0, width] and y in [0, height].
    IMAGE_BOX = np.array([[0, 0], [0, img_h], [img_w, img_h], [img_w, 0]])

    frame_name = os.path.basename(image_path)
    if mask_dir is not None:
        os.makedirs(mask_dir, exist_ok=True)
        mask_stem = frame_name.split(".")[0]

    # One list of connected-region masks per k-means cluster.
    segmented_masks = list()
    for idx in range(k):
        cluster_mask = (image == PALLTE[idx, :]).all(-1).astype(np.uint8) * 255
        cluster_mask = filter_mask(cluster_mask, MORPH_KS, DO_MORPH, MORPH_ITER)
        segmented_masks.append(get_disconnected_region(cluster_mask))

    hull_dict = dict()
    for class_idx, seg_mask in enumerate(segmented_masks):
        for region_idx, region_mask in enumerate(seg_mask):
            key = os.path.join(frame_name, "{}_{}".format(class_idx, region_idx))

            _, threshed_img = cv2.threshold(region_mask, 127, 255, cv2.THRESH_BINARY)
            # [-2] selects the contour list under both OpenCV 3.x (3 return
            # values) and OpenCV 4.x (2 return values).
            contours = cv2.findContours(threshed_img, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2]
            if len(contours) == 0:
                continue
            # First contour; for a single component this is expected to be its outer border.
            hull = cv2.convexHull(contours[0]).squeeze(1)
            if len(hull) <= 3:
                continue
            if get_iou(hull, IMAGE_BOX) >= FILTER_IOU_THRES:
                continue
            if get_hull_area(hull) <= FILTER_AREA_THRES:
                continue
            # Most expensive filter last: samples the flow inside the hull.
            mean_of = np.array(get_hull_of(of, hull), dtype=np.float16)
            if np.sum(np.abs(mean_of)) <= FILTER_OF_THRES:
                continue

            hull_dict[key] = {
                "hull": hull.astype(np.float16),
                "center": np.mean(hull, axis=0),
                "mean_of": mean_of,
                "next_obj": None,
            }
            if mask_dir is not None:
                save_name = "{}_{}_{}_contour.png".format(mask_stem, class_idx, region_idx)
                cv2.imwrite(os.path.join(mask_dir, save_name), region_mask)
    return hull_dict


def _read_flo(flow_path):
    '''Read a Middlebury-style ``.flo`` file into an (H, W, 2) float32 array.

    Layout read here: a float32 magic 202021.25, int32 width, int32 height,
    then width*height*2 float32 values in row-major order.

    Raises:
        ValueError: on a wrong magic number or a truncated file.
    '''
    with open(flow_path, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if magic.size != 1 or magic[0] != np.float32(202021.25):
            raise ValueError('Magic number incorrect. Invalid .flo file: {}'.format(flow_path))
        w = np.fromfile(f, np.int32, count=1)
        h = np.fromfile(f, np.int32, count=1)
        if w.size != 1 or h.size != 1:
            raise ValueError('Truncated .flo header: {}'.format(flow_path))
        width, height = int(w[0]), int(h[0])
        expected = 2 * width * height
        data = np.fromfile(f, np.float32, count=expected)
    # reshape (not np.resize) so a short file fails loudly instead of repeating data.
    if data.size != expected:
        raise ValueError('Truncated .flo data: expected {} values, got {} in {}'.format(
            expected, data.size, flow_path))
    return data.reshape((height, width, 2))


def flow2hull(image_path, param_dict, mask_dir):
    '''Load one optical-flow file and convert it into filtered convex hulls.

    The flow field is clustered with ``kmeans(of, param_dict["K"])`` and the
    clusters are turned into hulls by ``seg2geo``.

    Args:
        image_path: path to a ``.flo`` file.
        param_dict: configuration forwarded to kmeans / seg2geo.
        mask_dir: forwarded to seg2geo (directory for debug region masks, or None).

    Returns:
        The hull dict produced by seg2geo.

    Raises:
        NotImplementedError: for ``.png`` input, which carries no raw flow values.
        ValueError: if the ``.flo`` file is malformed or truncated.
        RuntimeError: for any other file extension.
    '''
    if image_path.endswith(".flo"):
        of = _read_flo(image_path)
    elif image_path.endswith(".png"):
        raise NotImplementedError(
            "PNG input is not supported; pass a .flo file: {}".format(image_path))
    else:
        raise RuntimeError("image type error")
    segmented_image = kmeans(of, param_dict["K"])
    return seg2geo(segmented_image, of, param_dict, image_path, mask_dir=mask_dir)


def get_iou(hull_1, hull_2):
    '''Intersection-over-union of two polygons given as sequences of (x, y) vertices.

    A 1e-5 epsilon in the denominator avoids division by zero when the union
    has zero area.
    '''
    poly_1, poly_2 = (Polygon(tuple(tuple(vertex) for vertex in hull)) for hull in (hull_1, hull_2))
    union_area = poly_1.union(poly_2).area
    overlap_area = poly_1.intersection(poly_2).area
    return overlap_area / (union_area + 1e-5)


def get_hull_area(hull):
    '''Area of the shapely Polygon whose vertices are listed in ``hull``.'''
    return Polygon(hull).area


def get_max_overlap_obj(hull_1, hull_dict):
    '''Find the entry of ``hull_dict`` whose "hull" has the highest IoU with ``hull_1``.

    Returns:
        [best_iou, best_name]; [0.0, None] if no entry has a strictly positive
        IoU. On ties, the earliest entry in iteration order wins.
    '''
    best = [0., None]
    for candidate_name, candidate in hull_dict.items():
        overlap = get_iou(hull_1, candidate["hull"])
        if overlap > best[0]:
            best = [overlap, candidate_name]
    return best



def draw_trace(hull_dict_all, vis_path, IMAGE_SHAPE):
    '''Placeholder for drawing object trajectories between frames; not implemented.

    In this file every hull is created with ``next_obj = None``, so in practice
    this function does nothing. ``vis_path`` and ``IMAGE_SHAPE`` are accepted
    for the intended drawing but are currently unused.

    Raises:
        NotImplementedError: if any hull carries a ``next_obj`` link. This
            replaces a leftover ``pdb.set_trace()`` that would block
            non-interactive runs.
    '''
    for hull_dict in hull_dict_all.values():
        for m_name, hull_info in hull_dict.items():
            if hull_info["next_obj"] is not None:
                raise NotImplementedError(
                    "draw_trace cannot draw the link from {!r} yet".format(m_name))


def draw_symbol(hull_dict_all, vis_dir, IMAGE_SHAPE):
    '''Draw each frame's hulls as closed outlines and save one PNG per frame.

    Args:
        hull_dict_all: dict mapping a flow-file name to that frame's hull dict
            (as returned by seg2geo); every entry needs a "hull" array of
            (x, y) vertices.
        vis_dir: output directory (created if missing); frames are saved as
            ``<stem>.png``.
        IMAGE_SHAPE: shape of the blank canvas, e.g. (H, W, 3).

    The last frame in ``hull_dict_all`` is skipped, as in the frame loop of
    cvt_flow2symbol. Each hull gets a colour chosen with random.choice from
    PALLTE.
    '''
    os.makedirs(vis_dir, exist_ok=True)
    for flow_path in list(hull_dict_all)[:-1]:
        mask = np.zeros(IMAGE_SHAPE)
        for hull_info in hull_dict_all[flow_path].values():
            hull = hull_info['hull'].astype(np.int16)
            color = random.choice(PALLTE.tolist())
            length = len(hull)
            for hull_idx in range(length):
                # Plain Python ints: OpenCV's point parser does not accept every numpy scalar type.
                start = tuple(int(v) for v in hull[hull_idx])
                end = tuple(int(v) for v in hull[(hull_idx + 1) % length])
                cv2.line(mask, start, end, color, 1)
        save_name = os.path.basename(flow_path).split(".")[0]
        cv2.imwrite(os.path.join(vis_dir, "{}.png".format(save_name)), mask)


def draw_hull(hull, vis_path, IMAGE_SHAPE):
    '''Draw one hull as a closed outline on a blank canvas and save it to ``vis_path``.

    Args:
        hull: (N, 2) array of (x, y) vertices; coordinates are truncated to int.
        vis_path: output image path; its parent directory is created if needed.
        IMAGE_SHAPE: canvas shape, e.g. (H, W, 3).

    Raises:
        RuntimeError: if a segment cannot be drawn; the original error is chained.
    '''
    out_dir = os.path.dirname(vis_path)
    if out_dir:  # a bare filename has no directory to create
        os.makedirs(out_dir, exist_ok=True)
    mask = np.zeros(IMAGE_SHAPE)
    length = hull.shape[0]
    color = random.choice(PALLTE.tolist())
    for hull_idx in range(length):
        try:
            start = tuple(int(v) for v in hull[hull_idx])
            end = tuple(int(v) for v in hull[(hull_idx + 1) % length])
            cv2.line(mask, start, end, color, 1)
        except Exception as err:
            raise RuntimeError("failed to draw hull: {!r}".format(hull)) from err
    cv2.imwrite(vis_path, mask)


def draw_sysobjs(sys_objs, vis_path, IMAGE_SHAPE, protagonist_):
    '''Draw the outlines of all 'post'/'alive' tracked objects into one image.

    Args:
        sys_objs: tracked objects. Each needs "status", "hull" and "role_id".
            The debug print also reads "inst_name", "life" and "to_remove".
        vis_path: output image path; its parent directory is created if needed.
        IMAGE_SHAPE: canvas shape, e.g. (H, W, 3).
        protagonist_: hull-area threshold above which a debug line is printed.

    Each outline is coloured ``PALLTE[role_id]``.

    Raises:
        RuntimeError: if a segment cannot be drawn; the original error is chained.
    '''
    out_dir = os.path.dirname(vis_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    mask = np.zeros(IMAGE_SHAPE)
    palette = PALLTE.tolist()
    for inst_val in sys_objs.values():
        if inst_val["status"] not in ["post", "alive"]:
            continue

        if get_hull_area(inst_val["hull"]) > protagonist_:
            print('in debug----2, protagonist not_supp, name=', inst_val['inst_name'], 'life = ', inst_val['life'], inst_val['to_remove'], "status:", inst_val['status'])

        hull = inst_val["hull"].astype(np.int16)
        color = palette[inst_val["role_id"]]
        length = len(hull)
        for hull_idx in range(length):
            try:
                start = tuple(int(v) for v in hull[hull_idx])
                end = tuple(int(v) for v in hull[(hull_idx + 1) % length])
                cv2.line(mask, start, end, color, 1)
            except Exception as err:
                raise RuntimeError("failed to draw hull of {!r}".format(inst_val.get("inst_name"))) from err
    cv2.imwrite(vis_path, mask)


def assign_new_obj(sys_objs, obj_name, **kwargs):
    '''(Re)create ``sys_objs[obj_name]`` as a fresh dict holding the keyword attributes.

    Any existing entry under ``obj_name`` is replaced. The same ``sys_objs``
    dict is returned for convenience.
    '''
    sys_objs[obj_name] = {key: value for key, value in kwargs.items()}
    return sys_objs


def conver_obj_attrib(sys_objs, obj_name, **kwargs):
    '''Overwrite or add attributes on the existing entry ``sys_objs[obj_name]``.

    With no keyword arguments nothing is looked up. Otherwise a missing
    ``obj_name`` raises KeyError. The same ``sys_objs`` dict is returned.
    '''
    if kwargs:
        sys_objs[obj_name].update(kwargs)
    return sys_objs


def intersect_value_2d(obj_a, obj_b):
    '''Overlap of two hull-carrying objects, as a fraction of each one's own area.

    Returns:
        [intersection / area(obj_a), intersection / area(obj_b)], where each area
        is that of the shapely Polygon built from obj["hull"].
    '''
    poly_a, poly_b = Polygon(obj_a["hull"]), Polygon(obj_b["hull"])
    shared = poly_a.intersection(poly_b).area
    return [shared / poly_a.area, shared / poly_b.area]

def renew_convex_hull(obj_list):
    '''Merge several objects into one point set and one averaged flow vector.

    Despite the name, no convex hull is computed: the "hull" arrays are
    concatenated along axis 0.

    Returns:
        (points, mean_of), where ``points`` stacks every obj["hull"] and
        ``mean_of`` is the element-wise mean of the objects' "mean_of" vectors.
    '''
    points = np.concatenate([obj["hull"] for obj in obj_list])
    flows = np.stack([obj["mean_of"] for obj in obj_list], axis=1)
    return points, np.mean(flows, axis=1)


def inertia_update(obj):
    '''Shift an object's hull by its flow vector, then reset the flow to zero.

    ``obj["hull"]`` is updated in place (``+=``) by ``obj["mean_of"]``, and
    ``obj["mean_of"]`` is replaced with zeros of the same shape and dtype.
    The same dict is returned.
    '''
    shift = obj["mean_of"]
    hull = obj["hull"]
    hull += shift
    obj["hull"] = hull
    obj.update(mean_of=np.zeros_like(shift))
    return obj


def get_role_template(template_dir):
    '''Load role silhouette templates and return their convex hulls.

    Every ``*.png`` directly inside ``template_dir`` is read. Only its first
    channel is used, thresholded at 127. The part of the file name before
    ".png" becomes the role name.

    Returns:
        dict mapping role name to the (N, 2) convex-hull vertices of the first
        contour found. Files are visited in sorted order, so the dict order is
        deterministic; get_obj_rolename breaks IoU ties by this order.

    Raises:
        ValueError: if a template cannot be read or has no pixel above 127.
    '''
    role_dict = dict()
    # escape() so glob metacharacters in the directory name are matched literally.
    for img_p in sorted(glob(os.path.join(escape(template_dir), "*.png"))):
        mask = cv2.imread(img_p)
        if mask is None:
            raise ValueError("could not read template image: {}".format(img_p))
        if len(mask.shape) != 2:
            mask = mask[:, :, 0]
        role_name = os.path.basename(img_p).split(".png")[0]
        _, threshed_img = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        # [-2] selects the contour list under both OpenCV 3.x and 4.x return conventions.
        contours = cv2.findContours(threshed_img, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)[-2]
        if len(contours) == 0:
            raise ValueError("template has no foreground above 127: {}".format(img_p))
        role_dict[role_name] = cv2.convexHull(contours[0]).squeeze(1)
    return role_dict
    

def do_tracking_v2(sys_objs, hull_dict_all, flow_path, p_dict):
    '''Advance the tracked "system objects" by one frame of flow detections.

    Steps:
      1. Match every detection of the current frame against every system
         object using the two intersection ratios of intersect_value_2d:
         * detection mostly inside a system object: the detection is appended
           to that object's covered_list (it supports it);
         * system object mostly inside the detection: a new 'prev' system
           object is spawned from the detection, with the covered objects as
           its children;
         * neither, but both tiny and close (distance_of): the detection becomes
           the object's only supporter.
         A detection that matches nothing spawns an isolated 'prev' object.
      2. Update each object's support, life counter and status:
         * 'prev' becomes 'alive' once life reaches 0;
         * 'alive' becomes 'post' when unsupported;
         * 'post' is removed when life runs out, or returns to 'alive' when
           supported again.
      3. For supported objects, store the concatenated supporter hulls under
         "convex_hull" and the averaged supporter flow under "mean_of".
      4. Shift every hull by its "mean_of" (inertia_update), which resets the flow.

    Args:
        sys_objs: dict of tracked objects keyed by instance name; mutated in place.
        hull_dict_all: dict mapping flow-file basename to the hull dict from seg2geo.
        flow_path: flow file of the current frame (only its basename is used).
        p_dict: configuration; reads TH_A_INSIDE_B, TH_B_INSIDE_A, TH_SUPP,
            LIFE_PREV, LIFE_POST, protagonist_, TH_TINY_OBJ_AREA and
            TH_TINY_OBJ_DIST.

    Returns:
        The same ``sys_objs`` dict.
    '''
    # ----------- hparams -----------
    th_a_inside_b = p_dict["TH_A_INSIDE_B"]
    th_b_inside_a = p_dict["TH_B_INSIDE_A"]
    th_supp = p_dict["TH_SUPP"]
    # Initial life of a new 'prev' object. Expected to be negative: life rises by 1
    # per supported frame, the object becomes 'alive' once life reaches 0, and it is
    # dropped if life falls below this value.
    life_prev = p_dict["LIFE_PREV"]
    # Life given to an 'alive' object that loses support ('post'). It loses 1 per
    # unsupported frame and is removed at <= 0; one supported frame revives it.
    life_post = p_dict["LIFE_POST"]
    protagonist_ = p_dict["protagonist_"]  # hull-area threshold for debug prints

    def debug_protagonist(label, obj, area):
        # Diagnostic print for large objects that are currently 'post' or 'alive'.
        if area > protagonist_ and obj['status'] in ["post", "alive"]:
            print(label, obj['inst_name'], 'life = ', obj['life'], obj['to_remove'], "status:", obj['status'])

    def spawn_prev(objf, children):
        # New 'prev' system object cloned from a detection; `children` are the
        # system objects it covers (also its initial supporters).
        new_sys_obj = copy.deepcopy(objf)
        new_sys_obj["status"] = 'prev'
        new_sys_obj["life"] = life_prev
        new_sys_obj["init"] = True
        new_sys_obj["covered_list"] = list(children)
        new_sys_obj["child_sys_objs"] = list(children)
        return new_sys_obj

    for objy in sys_objs.values():
        objy["covered_list"] = []

    # ----------- iterate each detection of this frame -----------
    of_objs = hull_dict_all[os.path.basename(flow_path)]
    for inst_name_f, inst_val_f in of_objs.items():
        objf = {
            "hull": inst_val_f["hull"],
            "mean_of": inst_val_f["mean_of"],
            "has_covered_sys_obj": None,
            "been_catched": False,
            "to_remove": False,
            "inst_name": inst_name_f,
            "father_sys_obj": None,
            "is_tiny": get_hull_area(inst_val_f["hull"]) < p_dict["TH_TINY_OBJ_AREA"],
        }
        # Snapshot of keys: objects spawned below are not matched against this detection.
        for inst_name_y in list(sys_objs.keys()):
            objy = sys_objs[inst_name_y]
            vints = intersect_value_2d(objf, objy)

            if vints[0] > th_a_inside_b:  # detection inside system object: it supports it
                objf["been_catched"] = True
                objy["covered_list"].append(objf)
            elif vints[0] <= th_a_inside_b and vints[1] > th_b_inside_a:  # detection covers system object
                objf["been_catched"] = True
                if objf["has_covered_sys_obj"] is None:
                    # first covered object: spawn the mother object
                    new_sys_obj = spawn_prev(objf, [objy])
                    sys_objs[inst_name_f] = new_sys_obj
                    objf["has_covered_sys_obj"] = new_sys_obj
                else:
                    # mother already spawned: add this object as another child
                    new_sys_obj = objf["has_covered_sys_obj"]
                    new_sys_obj["covered_list"].append(objy)
                    new_sys_obj["child_sys_objs"].append(objy)
            elif vints[0] <= th_a_inside_b and vints[1] <= th_b_inside_a:
                if objf["is_tiny"] and objy['is_tiny'] and distance_of(objf, objy) < p_dict["TH_TINY_OBJ_DIST"]:
                    # Deliberate assignment (not append): the nearby tiny detection
                    # becomes the only supporter.
                    objy["covered_list"] = [objf]
                # otherwise the pair is unrelated
            else:
                # Only reachable when a ratio is NaN (all comparisons false).
                print("this case is impossible?? vints:", vints)

        # a detection that overlaps nothing starts its own object
        if not objf["been_catched"]:
            new_sys_obj = spawn_prev(objf, [])
            new_sys_obj["is_tiny"] = objf["is_tiny"]
            sys_objs[inst_name_f] = new_sys_obj

    # ----------- post process: support, life and status -----------
    for inst_name_y in list(sys_objs.keys()):
        if inst_name_y not in sys_objs:
            # already popped as the child of a mother object that became alive
            continue
        objy = sys_objs[inst_name_y]
        hull_area = get_hull_area(objy["hull"])  # hull is unchanged until inertia_update

        debug_protagonist('in debug----1, protagonist not_supp, name=', objy, hull_area)

        sum_covered_area = sum(get_hull_area(obj["hull"]) for obj in objy["covered_list"])

        if (sum_covered_area / hull_area) < th_supp:  # support is NOT enough
            debug_protagonist('in debug#1, protagonist not_supp, name=', objy, hull_area)
            objy["is_supported"] = False
            if objy["status"] == 'prev':
                if "init" in objy:
                    if objy["init"]:
                        # the first unsupported frame after creation still counts +1
                        objy["life"] += 1
                        objy["init"] = False
                    else:
                        objy["life"] -= 1
            elif objy["status"] == 'post':
                objy["life"] -= 1
        else:  # support is large enough
            debug_protagonist('in debug ????1, protagonist not_supp, name=', objy, hull_area)
            objy["is_supported"] = True
            if objy["status"] == 'prev':
                objy["life"] += 1

        # ----------- change life and status -----------
        if objy["status"] == 'prev':
            if objy["life"] < life_prev:
                objy["to_remove"] = True
            if objy["life"] >= 0:
                objy["status"] = 'alive'
                objy["life"] = float('inf')
                # mother object became stable, hence delete its children
                for child in objy["child_sys_objs"]:
                    if child["inst_name"] in sys_objs:
                        sys_objs.pop(child["inst_name"])
        elif objy["status"] == 'post':
            if objy["life"] <= 0:
                objy["to_remove"] = True
                debug_protagonist('in debug#2, protagonist not_supp, name=', objy, hull_area)
            elif objy["is_supported"]:
                objy["status"] = 'alive'
                objy["life"] = float('inf')
        elif objy["status"] == 'alive' and not objy["is_supported"]:
            objy["status"] = 'post'
            objy["life"] = life_post
            debug_protagonist('in debug#0, protagonist not_supp, name=', objy, hull_area)

        # ----------- to_remove -----------
        if objy["to_remove"]:
            debug_protagonist('in debug#3, protagonist not_supp, name=', objy, hull_area)
            if objy["inst_name"] in sys_objs:
                sys_objs.pop(objy["inst_name"])

    # ----------- renew shape / flow from supporters -----------
    for objy in list(sys_objs.values()):
        if objy["is_supported"]:
            new_hull, new_of = renew_convex_hull(objy["covered_list"])
            # NOTE(review): stored under "convex_hull", not "hull", so the tracked
            # hull itself is not replaced here; confirm this is intended.
            objy["convex_hull"], objy["mean_of"] = new_hull, new_of

    # ----------- move every hull by its flow -----------
    for objy in list(sys_objs.values()):
        inertia_update(objy)
    print("sys objs length:", len(sys_objs))

    return sys_objs


def get_obj_rolename(sys_objs_, role_dict_, p_dict):
    '''Assign each tracked object the role whose template shape overlaps it most.

    The object hull and every template hull are each translated so that their
    vertex mean sits at the origin. The role with the highest IoU (get_iou)
    wins; ties go to the earlier role in ``role_dict_`` order.

    Args:
        sys_objs_: tracked objects; each needs a "hull" array. Not modified.
        role_dict_: role name to template hull (see get_role_template). Not modified.
        p_dict: configuration; reads ROLE_ID_MAP (role name to integer id).

    Returns:
        A deep copy of ``sys_objs_`` in which every object has a "role_id".

    Raises:
        KeyError: if no template overlaps an object (e.g. there are no
            templates), or if the winning role is missing from ROLE_ID_MAP.
    '''
    sys_objs = copy.deepcopy(sys_objs_)

    # Templates are only read, so centre them once rather than once per object.
    centered_templates = []
    for role_name, template in role_dict_.items():
        hull_temp = template.astype(np.float64)  # astype already returns a copy
        centered_templates.append((role_name, hull_temp - np.mean(hull_temp, axis=0)))

    for inst_name_y, objy in sys_objs.items():
        hull = objy["hull"].astype(np.float64)
        hull_ = hull - np.mean(hull, axis=0)
        max_iou, max_name = 0, None
        for role_name, hull_temp_ in centered_templates:
            iou = get_iou(hull_, hull_temp_)
            if max_iou < iou:
                max_iou, max_name = iou, role_name
        if max_name is None:
            raise KeyError("no role template overlaps object {!r}".format(inst_name_y))
        objy["role_id"] = p_dict["ROLE_ID_MAP"][max_name]
    return sys_objs

def distance_of(objf, objy, dist_type="L1"):
    '''Distance between the vertex means (centroids) of two objects' hulls.

    Only ``dist_type="L1"`` (sum of absolute coordinate differences) is
    supported; any other value raises NotImplementedError.
    '''
    centroid_f, centroid_y = (np.mean(obj["hull"], axis=0) for obj in (objf, objy))
    if dist_type == "L1":
        return np.sum(np.abs(centroid_f - centroid_y))
    raise NotImplementedError


def get_obj_name(flow_path, inst_name):
    '''Build an instance key: ``os.path.join(basename(flow_path), inst_name)``.'''
    frame_name = os.path.basename(flow_path)
    return os.path.join(frame_name, inst_name)


def cvt_flow2symbol(flow_dir, template_dir,
                    param_dict, game, mask_dir=None, symbol_dir=None, max_frames=100):
    '''Convert a directory of ``.flo`` files into tracked, role-labelled symbols.

    Args:
        flow_dir: directory searched (non-recursively) for ``*.flo`` files,
            processed in sorted name order.
        template_dir: directory of ``*.png`` role templates (see get_role_template).
        param_dict: configuration dict shared with the helper functions.
        game: game name; currently unused.
        mask_dir: if given, seg2geo writes per-region debug masks here.
        symbol_dir: if given, one drawing per tracked frame is written here as
            ``<stem>.png``.
        max_frames: read at most this many flow files (default 100, the
            previously hard-coded limit); None reads all of them.

    Returns:
        None. Progress is printed and drawings are optionally written.
    '''
    p_dict = param_dict
    role_dict = get_role_template(template_dir)
    flow_path_list = sorted(glob(os.path.join(escape(flow_dir), "*.flo")))
    if max_frames is not None:
        flow_path_list = flow_path_list[:max_frames]

    ###### save all instance information into a dict ######
    hull_dict_all = dict()
    for flow_path in tqdm(flow_path_list):
        hull_dict_all[os.path.basename(flow_path)] = flow2hull(flow_path, param_dict, mask_dir=mask_dir)

    ###### init sys_objs and iterate each frame ######
    sys_objs = dict()
    frame_names = list(hull_dict_all)
    # NOTE(review): the last frame is never tracked or drawn; confirm this is intended.
    for idx, flow_path in enumerate(frame_names[:-1]):
        print("flow_path=>", os.path.basename(flow_path))
        frame_curr = hull_dict_all[flow_path]
        if idx == 0:
            # the first frame seeds the system objects, all starting as 'prev'
            for inst_name, inst_val in frame_curr.items():
                kw_args = {
                    "hull": inst_val["hull"],
                    "mean_of": inst_val["mean_of"],
                    "life": p_dict["LIFE_PREV"],
                    "status": 'prev',
                    "covered_list": list(),
                    "father_sys_obj": None,
                    "to_remove": False,
                    "is_supported": False,
                    "child_sys_objs": list(),
                    "inst_name": inst_name,
                    "is_tiny": get_hull_area(inst_val["hull"]) < p_dict["TH_TINY_OBJ_AREA"],
                }
                assign_new_obj(sys_objs, inst_name, **kw_args)
        else:
            sys_objs = do_tracking_v2(sys_objs, hull_dict_all, flow_path, p_dict)

        sys_objs = get_obj_rolename(sys_objs, role_dict, p_dict)
        print("sys_objs.keys:", len(sys_objs))
        if symbol_dir is not None:
            stem = os.path.basename(flow_path).split(".")[0]
            draw_sysobjs(sys_objs, os.path.join(symbol_dir, "{}.png".format(stem)),
                         p_dict["IMAGE_SHAPE"], p_dict["protagonist_"])

if __name__ == "__main__":
    # Defaults reproduce the original experiment setup; every path can be
    # overridden on the command line.
    parser = argparse.ArgumentParser(
        description="Convert optical-flow (.flo) files into tracked geometric symbols.")
    parser.add_argument("--flow-dir",
                        default="/home/zhiwen/projects/flownet2-pytorch/output/Airstriker-Genesis",
                        help="directory containing the *.flo files")
    parser.add_argument("--symbol-dir",
                        default="/home/zhiwen/projects/flownet2-pytorch/output/draw_sysobjs",
                        help="output directory for per-frame symbol drawings")
    parser.add_argument("--mask-dir",
                        default="/home/zhiwen/projects/flownet2-pytorch/output/draw_mask/",
                        help="output directory for per-region debug masks")
    parser.add_argument("--template-dir",
                        default="/home/zhiwen/projects/flownet2-pytorch/output/templates/Airstriker-Genesis/",
                        help="directory of *.png role templates")
    args = parser.parse_args()

    # Kept as module-level names, in case other code reads them as globals.
    path1_in = args.flow_dir
    path2 = args.symbol_dir
    path3_tmp = args.mask_dir
    path4 = args.template_dir

    # Alternative local setup:
    # path1_in = "/Users/wenqzhen/Desktop/debug_OF/Airstriker-Genesis"
    # path2 = "/Users/wenqzhen/Desktop/debug_OF/out/out1"
    # path3_tmp = "/Users/wenqzhen/Desktop/debug_OF/out/out2"
    # path4 = "/Users/wenqzhen/Desktop/debug_OF/templates/Airstriker-Genesis"

    airstriker_genesis = {
        "protagonist_": 280,  # hull area above which debug prints fire
        "K": 4,  # number of k-means clusters (at most len(PALLTE))
        "ROLE_ID_MAP": {"protagonist":0, "enemy":1, "bullet":2},
        "DO_MORPH": True,
        "MORPH_KS": 2,
        "MORPH_ITER": 1,
        "TH_A_INSIDE_B": 0.7,
        "TH_B_INSIDE_A": 0.7,
        "LIFE_PREV": -3,  # negative: a new object needs ~3 net supported frames to become alive
        "LIFE_POST": 5,  # unsupported frames an alive object survives before removal
        "TH_SUPP": 0.4,
        "IOU_THRES": 0.5,  # not read in this file
        "FILTER_IOU_THRES": 0.1,  # hulls whose IoU with the whole-frame box is >= this are dropped; smaller is stricter
        "FILTER_OF_THRES": 0.1,  # hulls whose sum(|median flow|) is <= this are dropped
        "FILTER_AREA_THRES": 1,  # hulls with area (pixels) <= this are dropped; larger is stricter
        "INSIDE_THRES": 0.5,  # not read in this file
        "IMAGE_SHAPE": (224, 320, 3),
        "TH_TINY_OBJ_AREA": 20,
        "TH_TINY_OBJ_DIST": 10,
    }

    cvt_flow2symbol(flow_dir=path1_in, template_dir=path4,
        param_dict=airstriker_genesis, game="Airstriker-Genesis", mask_dir=path3_tmp, symbol_dir=path2)